import operator
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

from generic.data_util import print_game_events_info, read_feature_mean_scale, Transition, \
    reverse_standard_data, ICEHOCKEY_ACTIONS, reward_look_ahead, \
    summarize_location_bin_samples, normalization_01, summarize_feature_bin_samples, divide_dataset_according2date, \
    compute_distance_by_teams, aggregate_compute_distance_by_team, SOCCER_ACTIONS, get_game_time, read_player_info, \
    read_data
from sklearn.manifold import TSNE
from density_model.maf_model_bak import validate_maf
from generic.model_util import to_np, get_distrib_q_model_save_path
from generic.plot_util import plot_curve, plot_shadow_curve, plot_heatmap, plot_histogram, plot_scatter


def _filter_records_by_uncertainty(actions, samples_dict, outcome_dict, uncertainty_dict, threshold):
    """Gather sample/outcome records whose uncertainty does not exceed ``threshold``.

    :param actions: iterable of action names to gather records for.
    :param samples_dict: maps an action name to its list of model-output records.
    :param outcome_dict: maps an action name to its list of empirical-outcome records;
        it is read with the same indices as ``samples_dict``.
    :param uncertainty_dict: maps an action name to its list of uncertainty records
        (read with the same indices), or None to keep every record.
    :param threshold: a record is kept when the first entry of its uncertainty
        record is ``<= threshold``.
    :return: ``(kept_samples, kept_outcomes)``, two lists of equal length.
    """
    kept_samples = []
    kept_outcomes = []
    for action in actions:
        for idx, sample in enumerate(samples_dict[action]):
            if uncertainty_dict is None or uncertainty_dict[action][idx][0] <= threshold:
                kept_samples.append(sample)
                kept_outcomes.append(outcome_dict[action][idx])
    return kept_samples, kept_outcomes


def contextualized_empirical_risk_measure(agent,
                                          model_label,
                                          episode_num,
                                          sports,
                                          target_action='all',
                                          mode='valid',
                                          sanity_check_msg=None,
                                          debug_mode=False,
                                          uncertainty_model='gda',
                                          ):
    """Compare model outputs with empirical outcomes inside context-feature bins.

    For every uncertainty threshold the records passing the threshold are binned
    by the sport-specific context features, and for each bin the absolute
    difference between the mean model output and the mean empirical outcome is
    computed. Per threshold, the differences are written to
    ``./empirical_results_for_calibration/<model_label>/calibration_results_uncertain_<threshold>_epi-<episode_num>.txt``;
    a grouped bar plot over all thresholds is saved next to them as
    ``difference_bar_plot_epi-<episode_num>.png``.

    :param agent: project agent; ``train_data_path``, ``train_rate``,
        ``apply_data_date_div`` and (when ``uncertainty_model`` is not None)
        ``uncertainty_thresholds`` are read from it. The agent is not modified here.
    :param model_label: name of the sub-directory that receives the results.
    :param episode_num: episode identifier embedded in the output file names.
    :param sports: ``'ice-hockey'`` or ``'soccer'``.
    :param target_action: ``'all'`` to use every action found in the data, or a
        single action name.
    :param mode: ``'train'`` or ``'valid'`` select that split; any other value
        selects the testing split.
    :param sanity_check_msg: forwarded to ``compute_action_output_with_features``.
    :param debug_mode: when true, only the first two games are used and the
        per-threshold report is printed.
    :param uncertainty_model: forwarded to ``compute_action_output_with_features``;
        ``None`` disables uncertainty filtering and uses one infinite threshold.
    :raises ValueError: if ``sports`` is neither ``'ice-hockey'`` nor ``'soccer'``.
    :return: None.
    """
    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, valid_files, testing_files = \
        divide_dataset_according2date(all_data_files=all_files,
                                      train_rate=agent.train_rate,
                                      sports=sports,
                                      if_split=agent.apply_data_date_div)

    if sports == 'ice-hockey':
        interested_feature_names = ['Penalty', 'scoreDifferential', 'time_remained']
        interested_feature_split_dict = {
            'Penalty': [-float('inf'), -0.5, 0.5, float('inf')],
            'scoreDifferential': [-float('inf'), -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, float('inf')],
            'time_remained': [-float('inf'), 1200, 2400, float('inf')]
        }
        interested_location_feature_names = ('xAdjCoord', 'yAdjCoord')
    elif sports == 'soccer':
        interested_feature_names = ['manPower', 'scoreDiff', 'gameTimeRemain']
        interested_feature_split_dict = {
            'manPower': [-float('inf'), -0.5, 0.5, float('inf')],
            'scoreDiff': [-float('inf'), -2.5, -1.5, -0.5, 0.5, 1.5, 2.5, float('inf')],
            'gameTimeRemain': [-float('inf'), 45, float('inf')]
        }
        interested_location_feature_names = ('x', 'y')
    else:
        raise ValueError("Unknown sports: {0}".format(sports))

    if mode == 'train':
        game_files = training_files
    elif mode == 'valid':
        game_files = valid_files
    else:
        game_files = testing_files

    if debug_mode:
        game_files = game_files[:2]

    action_home_samples_dict = {}
    action_home_outcome_dict = {}
    action_away_samples_dict = {}
    action_away_outcome_dict = {}
    action_home_uncertainty_dict = {}
    action_away_uncertainty_dict = {}
    for game_name in game_files:
        # The accumulator dicts are threaded through the helpers game by game.
        action_home_samples_dict = \
            compute_action_output_with_features(agent=agent,
                                                game_name=game_name,
                                                is_home=1,
                                                action_samples_dict=action_home_samples_dict,
                                                interested_feature_names=interested_feature_names,
                                                sanity_check_msg=sanity_check_msg,
                                                sports=sports)
        action_home_outcome_dict = \
            empirical_action_output_with_features(agent=agent,
                                                  game_name=game_name,
                                                  is_home=1,
                                                  action_outcome_dict=action_home_outcome_dict,
                                                  interested_feature_names=interested_feature_names,
                                                  sports=sports)
        action_away_samples_dict = \
            compute_action_output_with_features(agent=agent,
                                                game_name=game_name,
                                                is_home=0,
                                                action_samples_dict=action_away_samples_dict,
                                                interested_feature_names=interested_feature_names,
                                                sanity_check_msg=sanity_check_msg,
                                                sports=sports)
        action_away_outcome_dict = \
            empirical_action_output_with_features(agent=agent,
                                                  game_name=game_name,
                                                  is_home=0,
                                                  action_outcome_dict=action_away_outcome_dict,
                                                  interested_feature_names=interested_feature_names,
                                                  sports=sports)
        if uncertainty_model is not None:
            action_home_uncertainty_dict = \
                compute_action_output_with_features(agent=agent,
                                                    game_name=game_name,
                                                    is_home=1,
                                                    sports=sports,
                                                    interested_feature_names=interested_location_feature_names,
                                                    action_samples_dict=action_home_uncertainty_dict,
                                                    sanity_check_msg=sanity_check_msg,
                                                    output_type='Uncertainty',
                                                    uncertainty_model=uncertainty_model)
            action_away_uncertainty_dict = \
                compute_action_output_with_features(agent=agent,
                                                    game_name=game_name,
                                                    is_home=0,
                                                    sports=sports,
                                                    interested_feature_names=interested_location_feature_names,
                                                    action_samples_dict=action_away_uncertainty_dict,
                                                    sanity_check_msg=sanity_check_msg,
                                                    output_type='Uncertainty',
                                                    uncertainty_model=uncertainty_model)

    if uncertainty_model is None:
        # No uncertainty estimates: a single infinite threshold keeps every record.
        # A local list is used so the agent's configured thresholds stay intact.
        uncertainty_thresholds = [float('inf')]
        home_uncertainty_dict = None
        away_uncertainty_dict = None
    else:
        uncertainty_thresholds = list(agent.uncertainty_thresholds)
        home_uncertainty_dict = action_home_uncertainty_dict
        away_uncertainty_dict = action_away_uncertainty_dict

    if target_action == 'all':
        all_home_action = list(action_home_samples_dict)
        all_away_action = list(action_away_samples_dict)
    else:
        all_home_action = [target_action]
        all_away_action = [target_action]

    # One label per (feature, bin); the bins do not depend on the threshold.
    bar_plot_uncertainty_label = []
    for feature_name in interested_feature_names:
        split_points = interested_feature_split_dict[feature_name]
        for lower, upper in zip(split_points[:-1], split_points[1:]):
            bar_plot_uncertainty_label.append('{0}<{1}<{2}'.format(lower, feature_name, upper))

    result_dir = './empirical_results_for_calibration/' + model_label
    # makedirs also creates the parent directory and tolerates an existing one.
    os.makedirs(result_dir, exist_ok=True)

    bar_plot_data = {}
    for uncertainty_threshold in uncertainty_thresholds:
        sample_home_location_data, outcome_home_location_data = \
            _filter_records_by_uncertainty(actions=all_home_action,
                                           samples_dict=action_home_samples_dict,
                                           outcome_dict=action_home_outcome_dict,
                                           uncertainty_dict=home_uncertainty_dict,
                                           threshold=uncertainty_threshold)
        sample_away_location_data, outcome_away_location_data = \
            _filter_records_by_uncertainty(actions=all_away_action,
                                           samples_dict=action_away_samples_dict,
                                           outcome_dict=action_away_outcome_dict,
                                           uncertainty_dict=away_uncertainty_dict,
                                           threshold=uncertainty_threshold)

        # NOTE(review): presumably sample records hold per-team outputs in item[0]
        # (index 0 = home, 1 = away) and the context features in item[2], while
        # outcome records hold the home/away outcomes in item[0]/item[1] and the
        # features from item[3] on -- confirm against the helper functions.
        sample_home_bin_feature_value_store, sample_home_feature_label_meanings = \
            summarize_feature_bin_samples(
                feature_exp_values=[[item[0][0], item[2]] for item in sample_home_location_data],
                interested_feature_names=interested_feature_names,
                interested_feature_split_dict=interested_feature_split_dict)
        output_home_bin_feature_value_store, _ = \
            summarize_feature_bin_samples(
                feature_exp_values=[[np.asarray([item[0]]), item[3:]] for item in outcome_home_location_data],
                interested_feature_names=interested_feature_names,
                interested_feature_split_dict=interested_feature_split_dict)
        sample_away_bin_feature_value_store, sample_away_feature_label_meanings = \
            summarize_feature_bin_samples(
                feature_exp_values=[[item[0][1], item[2]] for item in sample_away_location_data],
                interested_feature_names=interested_feature_names,
                interested_feature_split_dict=interested_feature_split_dict)
        output_away_bin_feature_value_store, _ = \
            summarize_feature_bin_samples(
                feature_exp_values=[[np.asarray([item[1]]), item[3:]] for item in outcome_away_location_data],
                interested_feature_names=interested_feature_names,
                interested_feature_split_dict=interested_feature_split_dict)

        bar_plot_uncertainty_data = []
        record_lines = []
        for name_idx, feature_name in enumerate(interested_feature_names):
            split_points = interested_feature_split_dict[feature_name]
            for value_idx in range(len(split_points) - 1):
                feature_value_1 = split_points[value_idx]
                feature_value_2 = split_points[value_idx + 1]

                interest_feature_value_store, interest_feature_output_store = \
                    aggregate_compute_distance_by_team(
                        sample_home_bin_fea_value_store=sample_home_bin_feature_value_store,
                        output_home_bin_fea_value_store=output_home_bin_feature_value_store,
                        sample_home_fea_label_meanings=sample_home_feature_label_meanings,
                        sample_away_bin_fea_value_store=sample_away_bin_feature_value_store,
                        output_away_bin_fea_value_store=output_away_bin_feature_value_store,
                        sample_away_fea_label_meanings=sample_away_feature_label_meanings,
                        interest_fea_label_pair=[name_idx, value_idx]
                    )

                # Explicit len() checks: the stores may be arrays, whose truth value is ambiguous.
                if len(interest_feature_value_store) > 0 and len(interest_feature_output_store) > 0:
                    difference = np.abs(np.mean(interest_feature_value_store) - np.mean(interest_feature_output_store))
                    record_lines.append("{0}<{1}<{2}, difference: {3} ".format(
                        feature_value_1, feature_name, feature_value_2, difference))
                    bar_plot_uncertainty_data.append(difference)
                else:
                    # Empty bin: keep the bar position, mark the difference as infinite.
                    bar_plot_uncertainty_data.append(float('inf'))
        bar_plot_data[uncertainty_threshold] = bar_plot_uncertainty_data
        all_record_msg = ''.join(line + '\n' for line in record_lines)

        if debug_mode:
            print(uncertainty_threshold)
            print(all_record_msg)

        with open(result_dir + '/calibration_results' +
                  '_uncertain_{0}_epi-{1}.txt'.format(uncertainty_threshold, episode_num),
                  'w') as record_file:
            record_file.write(all_record_msg)

    import matplotlib.pyplot as plt
    difference_matrix_bar_plot = pd.DataFrame(bar_plot_data)
    width = 0.3  # width of a bar
    difference_matrix_bar_plot[uncertainty_thresholds].plot(kind='bar', width=width, figsize=(25, 12))

    # The row count is independent of which thresholds are configured.
    plt.xlim([-width, len(difference_matrix_bar_plot) - width])
    plt.xticks(range(len(bar_plot_uncertainty_label)), bar_plot_uncertainty_label, rotation=45, fontsize=16)
    plt.savefig(result_dir + '/difference_bar_plot_epi-{0}.png'.format(episode_num))
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close()

    # predicted_exp_home_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                             len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                             len(interested_feature_split_dict['time_remained']) - 1])
    # outcome_exp_home_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                           len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                           len(interested_feature_split_dict['time_remained']) - 1])
    # predicted_exp_away_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                             len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                             len(interested_feature_split_dict['time_remained']) - 1])
    # outcome_exp_away_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                           len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                           len(interested_feature_split_dict['time_remained']) - 1])
    #
    # predicted_std_home_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                             len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                             len(interested_feature_split_dict['time_remained']) - 1])
    # outcome_std_home_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                           len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                           len(interested_feature_split_dict['time_remained']) - 1])
    # predicted_std_away_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                             len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                             len(interested_feature_split_dict['time_remained']) - 1])
    # outcome_std_away_value_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                                           len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                                           len(interested_feature_split_dict['time_remained']) - 1])
    #
    # home_num_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                             len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                             len(interested_feature_split_dict['time_remained']) - 1])
    # away_num_matrix = np.zeros([len(interested_feature_split_dict['Penalty']) - 1,
    #                             len(interested_feature_split_dict['scoreDifferential']) - 1,
    #                             len(interested_feature_split_dict['time_remained']) - 1])

    # for label in sample_home_bin_feature_value_store.keys():
    #     i, j, k = list(map(int, label.split('@')))
    #
    #     predicted_exp_home_value_matrix[i, j, k] = np.mean(sample_home_bin_feature_value_store[label])
    #     outcome_exp_home_value_matrix[i, j, k] = np.mean(output_home_bin_feature_value_store[label])
    #     home_num_matrix[i, j, k] = len(sample_home_bin_feature_value_store[label])
    #     predicted_std_home_value_matrix[i, j, k] = np.std(sample_home_bin_feature_value_store[label])
    #     outcome_std_home_value_matrix[i, j, k] = np.std(output_home_bin_feature_value_store[label])
    #
    #     if label in sample_away_bin_feature_value_store:
    #         predicted_exp_away_value_matrix[i, j, k] = np.mean(sample_away_bin_feature_value_store[label])
    #         outcome_exp_away_value_matrix[i, j, k] = np.mean(output_away_bin_feature_value_store[label])
    #         away_num_matrix[i, j, k] = len(sample_away_bin_feature_value_store[label])
    #         predicted_std_away_value_matrix[i, j, k] = np.std(sample_away_bin_feature_value_store[label])
    #         outcome_std_away_value_matrix[i, j, k] = np.std(output_away_bin_feature_value_store[label])
    #
    # for label in sample_away_bin_feature_value_store.keys():
    #     i, j, k = list(map(int, label.split('@')))
    #     if label in sample_home_bin_feature_value_store:
    #         predicted_exp_home_value_matrix[i, j, k] = np.mean(sample_home_bin_feature_value_store[label])
    #         outcome_exp_home_value_matrix[i, j, k] = np.mean(output_home_bin_feature_value_store[label])
    #         home_num_matrix[i, j, k] = len(sample_home_bin_feature_value_store[label])
    #         predicted_std_home_value_matrix[i, j, k] = np.std(sample_home_bin_feature_value_store[label])
    #         outcome_std_home_value_matrix[i, j, k] = np.std(output_home_bin_feature_value_store[label])
    #     predicted_exp_away_value_matrix[i, j, k] = np.mean(sample_away_bin_feature_value_store[label])
    #     outcome_exp_away_value_matrix[i, j, k] = np.mean(output_away_bin_feature_value_store[label])
    #     away_num_matrix[i, j, k] = len(sample_away_bin_feature_value_store[label])
    #     predicted_std_away_value_matrix[i, j, k] = np.std(sample_away_bin_feature_value_store[label])
    #     outcome_std_away_value_matrix[i, j, k] = np.std(output_away_bin_feature_value_store[label])
    #
    # all_record_msg, normalized_home_empirical_risks, normalized_away_empirical_risks, \
    # normalized_home_std_risks, normalized_away_std_risks = \
    #     compute_distance_by_teams(context_feature_split_dict=interested_feature_split_dict,
    #                               sample_home_feature_label_meanings=sample_home_feature_label_meanings,
    #                               predicted_exp_home_value_matrix=predicted_exp_home_value_matrix,
    #                               outcome_exp_home_value_matrix=outcome_exp_home_value_matrix,
    #                               predicted_std_home_value_matrix=predicted_std_home_value_matrix,
    #                               outcome_std_home_value_matrix=outcome_std_home_value_matrix,
    #                               home_num_matrix=home_num_matrix,
    #                               sample_away_feature_label_meanings=sample_away_feature_label_meanings,
    #                               predicted_exp_away_value_matrix=predicted_exp_away_value_matrix,
    #                               outcome_exp_away_value_matrix=outcome_exp_away_value_matrix,
    #                               predicted_std_away_value_matrix=predicted_std_away_value_matrix,
    #                               outcome_std_away_value_matrix=outcome_std_away_value_matrix,
    #                               away_num_matrix=away_num_matrix,
    #                               distance_measure=distance_measure)
    #
    # if verbose:
    #     print(all_record_msg, file=agent.log_file, flush=True)
    # if record_file is not None:
    #     record_file.write(all_record_msg)


# return np.mean(normalized_home_empirical_risks + normalized_away_empirical_risks), \
#        np.mean(normalized_home_std_risks + normalized_away_std_risks)


def visualize_uncertainty_by_location(agent,
                                      model_label,
                                      debug_mode=False,
                                      episode_num=0,
                                      is_home=1,  # home: 1, away, 0
                                      target_action='shot',
                                      mode='train',
                                      uncertainty_model='gda',
                                      sanity_check_msg=None,
                                      if_plot_num=False,
                                      log_file=None,
                                      ):
    """Plot a location heat map of the uncertainty values for one action.

    Uncertainty outputs are collected over the games of the selected split,
    binned on a 20x20 location grid by ``summarize_location_bin_samples`` and
    passed to ``plot_heatmap``. The directory
    ``./heatmaps_for_calibration/<model_label>`` is created if it is missing.

    :param agent: project agent; ``train_data_path``, ``train_rate``, ``sports``,
        ``apply_data_date_div`` and ``num_tau`` are read from it.
    :param model_label: label used for the output directory and the plot names.
    :param debug_mode: when true, only the first two games of the split are used.
    :param episode_num: episode identifier embedded in the plot names.
    :param is_home: 1 for the home team, 0 for the away team.
    :param target_action: name of the action whose records are plotted.
    :param mode: ``'train'`` or ``'valid'`` select that split; any other value
        selects the testing split.
    :param uncertainty_model: forwarded to ``compute_action_output_with_features``.
    :param sanity_check_msg: forwarded to ``compute_action_output_with_features``.
    :param if_plot_num: when true, a second heat map of the per-bin counts
        (``bin_num_values``) is plotted as well.
    :param log_file: currently unused.
    :raises ValueError: if ``agent.sports`` is neither ``'ice-hockey'`` nor ``'soccer'``.
    :return: None.
    """
    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, valid_files, testing_files = divide_dataset_according2date(all_data_files=all_files,
                                                                               train_rate=agent.train_rate,
                                                                               sports=agent.sports,
                                                                               if_split=agent.apply_data_date_div
                                                                               )
    # makedirs also creates the parent directory and tolerates an existing one.
    os.makedirs('./heatmaps_for_calibration/' + model_label, exist_ok=True)

    if mode == 'train':
        game_files = training_files
    elif mode == 'valid':
        game_files = valid_files
    else:
        game_files = testing_files

    if debug_mode:  # this is the debug mode
        game_files = game_files[:2]

    if agent.sports == 'ice-hockey':
        interested_location_feature_names = ('xAdjCoord', 'yAdjCoord')
    elif agent.sports == 'soccer':
        interested_location_feature_names = ('x', 'y')
    else:
        raise ValueError("Unknown sports: {0}".format(agent.sports))

    action_samples_dict = {}
    # Iterate the split chosen by `mode`/`debug_mode`; indexing `testing_files`
    # here would ignore that choice and could run past the end of the test split.
    for game_name in game_files:
        action_samples_dict = \
            compute_action_output_with_features(agent=agent,
                                                game_name=game_name,
                                                is_home=is_home,
                                                action_samples_dict=action_samples_dict,
                                                sanity_check_msg=sanity_check_msg,
                                                output_type='Uncertainty',
                                                uncertainty_model=uncertainty_model,
                                                sports=agent.sports,
                                                interested_feature_names=interested_location_feature_names)

    action_records = action_samples_dict[target_action]
    # Entry 0 of a record is used as the uncertainty value, entry 2 as the location.
    sample_uncertainty_data = np.stack([record[0] for record in action_records])
    sample_location_data = np.stack([record[2] for record in action_records])

    bin_size = 20
    _, bin_expect_values_uncertainty, _, _, bin_num_values = \
        summarize_location_bin_samples(locations=sample_location_data,
                                       values=sample_uncertainty_data,
                                       bin_x=bin_size, bin_y=bin_size,
                                       num_tau=agent.num_tau,
                                       empirical_uncertainty_from_samples_flag=False)
    # NOTE(review): the first five columns are dropped before plotting; presumably
    # this crops the map to the part of the playing surface of interest -- confirm.
    plot_heatmap(data_store=bin_expect_values_uncertainty[:, 5:],
                 plot_name='{2}/heat_map_{0}_uncertainty_bin_{1}_{2}_epi-{3}'.format(mode, bin_size, model_label,
                                                                                     episode_num))
    if if_plot_num:
        plot_heatmap(data_store=bin_num_values[:, 5:],
                     plot_name='{2}/heat_map_{0}_num_bin_{1}_{2}_epi-{3}'.format(mode, bin_size, model_label,
                                                                                 episode_num))


def visualize_distribution_plot(agent, is_home,
                                target_action, debug_mode, model_label,
                                sanity_check_msg, mode='test'):
    """Save one histogram image per record of ``target_action``.

    Model outputs are collected over the games of the selected split and each
    record's samples are passed to ``plot_histogram``; the images are written
    to ``./plot_example_imgs/<model_label>/``, which is created if missing.

    :param agent: project agent; ``train_data_path``, ``train_rate``, ``sports``,
        ``apply_data_date_div`` and ``log_file`` are read from it.
    :param is_home: truthy to plot the first (home) output of each record,
        falsy to plot the second (away) one.
    :param target_action: name of the action whose records are plotted.
    :param debug_mode: when true, only the first game of the split is used.
    :param model_label: name of the sub-directory that receives the images.
    :param sanity_check_msg: forwarded to ``compute_action_output_with_features``.
    :param mode: ``'train'`` or ``'valid'`` select that split; any other value
        selects the testing split.
    :raises ValueError: if ``agent.sports`` is neither ``'ice-hockey'`` nor ``'soccer'``.
    :return: None.
    """
    image_dir = './plot_example_imgs/' + model_label
    # makedirs also creates the parent directory and tolerates an existing one.
    os.makedirs(image_dir, exist_ok=True)

    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, valid_files, testing_files = divide_dataset_according2date(all_data_files=all_files,
                                                                               train_rate=agent.train_rate,
                                                                               sports=agent.sports,
                                                                               if_split=agent.apply_data_date_div
                                                                               )
    if mode == 'train':
        game_files = training_files
    elif mode == 'valid':
        game_files = valid_files
    else:
        game_files = testing_files

    if debug_mode:  # this is the debug mode
        game_files = game_files[:1]  # Game 16694, St. Louis Blues vs. Arizona Coyotes

    if agent.sports == 'ice-hockey':
        interested_location_feature_names = ('xAdjCoord', 'yAdjCoord')
    elif agent.sports == 'soccer':
        interested_location_feature_names = ('x', 'y')
    else:
        raise ValueError("Unknown sports: {0}".format(agent.sports))

    action_samples_dict = {}
    for game_name in tqdm(game_files, desc='Calculating game values.', file=agent.log_file):
        action_samples_dict = \
            compute_action_output_with_features(agent=agent,
                                                game_name=game_name,
                                                is_home=is_home,
                                                action_samples_dict=action_samples_dict,
                                                sanity_check_msg=sanity_check_msg,
                                                interested_feature_names=interested_location_feature_names,
                                                sports=agent.sports)

    action_records = action_samples_dict[target_action]
    # Entry 0 of a record holds the outputs for both teams: index 0 home, index 1 away.
    team_idx = 0 if is_home else 1

    if_density = True
    plot_kind = 'distribution' if not if_density else 'density'
    for i, record in enumerate(action_records):
        samples = record[0][team_idx]
        location = record[2]
        img_save_path = image_dir + \
                        '/{0}_idx_{1}_XCoord:{2}_YCoord:{3}.png'.format(
                            plot_kind,
                            i,
                            round(location[0], 2),
                            round(location[1], 2))

        plot_histogram(samples=samples,
                       img_save_path=img_save_path,
                       location=location,
                       i=i,
                       add_cdf=if_density)


def evaluate_maf(agent, sanity_check_msg, debug_mode=False, log_file=None):
    """Validate the agent's MAF model on the testing games.

    Each testing game is loaded, turned into transitions (RNN-style when
    ``agent.apply_rnn`` is set) and passed to ``validate_maf``; the per-sample
    losses and log-probabilities of all games are pooled.

    :param agent: project agent providing the data path, split settings,
        ``maf_model`` and the data-loading / transition-building methods.
    :param sanity_check_msg: forwarded to ``agent.load_sports_data``.
    :param debug_mode: when true, only ``testing_files[53:55]`` are evaluated.
    :param log_file: stream that receives the first 20 pooled log-probabilities
        (``None`` prints to standard output).
    :return: ``(mean loss, mean log-probability)`` over all pooled samples.
    """
    data_files = sorted(os.listdir(agent.train_data_path))
    _, _, testing_files = divide_dataset_according2date(all_data_files=data_files,
                                                        train_rate=agent.train_rate,
                                                        sports=agent.sports,
                                                        if_split=agent.apply_data_date_div)

    # Switch the model to evaluation mode before scoring any game.
    agent.maf_model.eval()

    if debug_mode:  # this is the debug mode
        testing_files = testing_files[53:55]

    pooled_losses, pooled_log_probs = [], []
    for game_label in testing_files:
        states_actions, rewards = agent.load_sports_data(game_label=game_label,
                                                         sanity_check_msg=sanity_check_msg)
        player_ids = agent.load_player_id(game_label=game_label)

        # Both builders take the same keyword arguments; pick one by configuration.
        build = agent.build_rnn_transitions if agent.apply_rnn else agent.build_transitions
        transitions = build(s_a_data=states_actions,
                            r_data=rewards,
                            pid_sequence=player_ids)

        # Transpose the list of transitions into one Transition of per-field tuples.
        batch = Transition(*zip(*transitions))
        game_loss, game_log_prob = validate_maf(agent=agent, batch=batch)
        pooled_losses.extend(game_loss.tolist())
        pooled_log_probs.extend(game_log_prob.tolist())

    print(pooled_log_probs[:20], file=log_file, flush=True)
    return np.mean(pooled_losses), np.mean(pooled_log_probs)


def visualize_exp_variance_by_location(agent,
                                       model_label,
                                       debug_mode=False,
                                       episode_num=0,
                                       is_home=1,  # home: 1, away, 0
                                       target_action='shot',
                                       mode='train',
                                       plot_advantage_heat_map=False,
                                       sanity_check_msg=None):
    """Draw location-binned heat maps of model outputs versus empirical outcomes.

    For every game of the selected split the model outputs and the
    look-ahead outcomes are gathered per action, restricted to
    ``target_action``, binned on a 20 x 20 grid and plotted (expectation,
    standard deviation, entropy, counts and the model/outcome differences).
    The correlation between each difference map and the outcome count map is
    printed to ``agent.log_file``.

    Args:
        agent: Object exposing the data path, split settings, ``sports``,
            ``num_tau`` and ``log_file``.
        model_label: Name used for the output sub-directory and plot names.
        debug_mode: When true, only the files at positions 53 and 54 of the
            selected split are processed.
        episode_num: Value embedded in the plot names.
        is_home: Truthy selects column 0 of the outputs/outcomes, falsy
            selects column 1; it is also forwarded to the per-game helpers.
        target_action: Action whose samples are plotted.
        mode: ``'train'`` or ``'valid'``; any other value selects the testing
            files.
        plot_advantage_heat_map: When true, also plot the advantage samples.
        sanity_check_msg: Forwarded to ``compute_action_output_with_features``.

    Raises:
        ValueError: If ``agent.sports`` is neither ``'ice-hockey'`` nor
            ``'soccer'``.
        KeyError: If ``target_action`` never occurs in the processed games.
    """
    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, valid_files, testing_files = divide_dataset_according2date(
        all_data_files=all_files,
        train_rate=agent.train_rate,
        sports=agent.sports,
        if_split=agent.apply_data_date_div,
    )
    # makedirs also creates a missing parent directory and tolerates reruns.
    os.makedirs('./heatmaps_for_calibration/' + model_label, exist_ok=True)

    if mode == 'train':
        game_files = training_files
    elif mode == 'valid':
        game_files = valid_files
    else:
        game_files = testing_files

    if debug_mode:  # this is the debug mode
        game_files = game_files[53:55]

    if agent.sports == 'ice-hockey':
        interested_location_feature_names = ('xAdjCoord', 'yAdjCoord')
    elif agent.sports == 'soccer':
        interested_location_feature_names = ('x', 'y')
    else:
        raise ValueError("Unknown sports: {0}".format(agent.sports))

    action_samples_dict = {}
    action_outcome_dict = {}
    for game_name in game_files:
        # Both helpers require `sports`; the same location features are
        # requested from each so the model and outcome maps share coordinates.
        action_samples_dict = compute_action_output_with_features(
            agent=agent,
            game_name=game_name,
            is_home=is_home,
            action_samples_dict=action_samples_dict,
            sports=agent.sports,
            interested_feature_names=interested_location_feature_names,
            sanity_check_msg=sanity_check_msg,
        )
        action_outcome_dict = empirical_action_output_with_features(
            agent=agent,
            game_name=game_name,
            is_home=is_home,
            sports=agent.sports,
            action_outcome_dict=action_outcome_dict,
            interested_feature_names=interested_location_feature_names,
        )

    # Each sample is [output, advantages, features]; each outcome row is the
    # three cumulative rewards followed by the location features.
    samples = action_samples_dict[target_action]
    # The outcomes are stored as a list of rows, so convert before slicing.
    outcomes = np.asarray(action_outcome_dict[target_action])

    team_idx = 0 if is_home else 1
    outcome_data = outcomes[:, team_idx]
    sample_Q_data = np.stack([sample[0][team_idx] for sample in samples])
    sample_advantage_data = np.stack([sample[1][team_idx] for sample in samples])

    outcome_location_data = outcomes[:, -2:]
    sample_location_data = np.stack([sample[2] for sample in samples])

    bin_size = 20

    def _plot(data_store, tag):
        # Every heat map skips the first five columns of the binned grid.
        plot_heatmap(data_store=data_store[:, 5:],
                     plot_name='{2}/heat_map_{0}_{4}_bin_{1}_{2}_epi-{3}'.format(
                         mode, bin_size, model_label, episode_num, tag))

    _, bin_expect_values_Q, bin_std_values_Q, bin_entropy_values_Q, _ = \
        summarize_location_bin_samples(locations=sample_location_data,
                                       values=sample_Q_data,
                                       bin_x=bin_size, bin_y=bin_size,
                                       num_tau=agent.num_tau)
    _plot(bin_expect_values_Q, 'exp_Q')
    _plot(bin_std_values_Q, 'std_Q')
    _plot(bin_entropy_values_Q, 'entropy_Q')

    if plot_advantage_heat_map:
        _, bin_expect_values_advantage, bin_std_values_advantage, _, bin_num_values_advantage = \
            summarize_location_bin_samples(locations=sample_location_data,
                                           values=sample_advantage_data,
                                           bin_x=bin_size, bin_y=bin_size,
                                           num_tau=agent.num_tau)
        _plot(bin_expect_values_advantage, 'exp_advantage')
        _plot(bin_std_values_advantage, 'std_advantage')
        _plot(bin_num_values_advantage, 'num_advantage')

    _, bin_expect_values_outcome, bin_std_values_outcome, \
        bin_entropy_values_outcome, bin_num_values_outcome = \
        summarize_location_bin_samples(locations=outcome_location_data,
                                       values=outcome_data,
                                       bin_x=bin_size, bin_y=bin_size,
                                       num_tau=agent.num_tau)
    _plot(bin_expect_values_outcome, 'exp_outcome')
    _plot(bin_std_values_outcome, 'std_outcome')
    _plot(bin_entropy_values_outcome, 'entropy_outcome')
    _plot(bin_num_values_outcome, 'num_outcome')

    bin_diff_expect = abs(bin_expect_values_outcome - bin_expect_values_Q)
    _plot(bin_diff_expect, 'exp_diff')

    # Spread statistics are rescaled with normalization_01 before differencing.
    bin_std_values_outcome = normalization_01(bin_std_values_outcome)
    bin_std_values_Q = normalization_01(bin_std_values_Q)
    bin_diff_std = abs(bin_std_values_outcome - bin_std_values_Q)
    _plot(bin_diff_std, 'std_diff')

    bin_entropy_values_outcome = normalization_01(bin_entropy_values_outcome)
    bin_entropy_values_Q = normalization_01(bin_entropy_values_Q)
    bin_diff_entropy = abs(bin_entropy_values_outcome - bin_entropy_values_Q)
    _plot(bin_diff_entropy, 'entropy_diff')

    # Correlate each model/outcome gap with the per-bin outcome counts.
    bin_counts = bin_num_values_outcome[:, 5:].flatten()
    correl_expect = np.corrcoef(bin_diff_expect[:, 5:].flatten(), bin_counts)
    print(correl_expect, file=agent.log_file, flush=True)

    correl_std = np.corrcoef(bin_diff_std[:, 5:].flatten(), bin_counts)
    print(correl_std, file=agent.log_file, flush=True)

    # tmp_value_2 = bin_expected_values[row_id, column_id]
    #
    # tmp_diff = tmp_value_1 - tmp_value_2
    #
    # print(tmp_diff)
    #
    # import matplotlib.pyplot as plt
    # for key in bin_location_store.keys():
    #     samples = bin_location_store[key]
    #     plt.figure()
    #     df = pd.DataFrame({"{0}".format(key): samples})
    #     ax = df.plot.hist(bins=12, alpha=0.5)
    #     plt.show()


# def compute_action_variance_by_location_all_game(agent,
#                                                  is_home=1,  # home: 1, away, 0
#                                                  target_action='shot'):
#     training_files = os.listdir(agent.train_data_path)
#     # print(training_files)
#     action_mean_dict = {}
#     action_var_dict = {}
#     action_samples_dict = {}
#     from tqdm import tqdm
#     for i in tqdm(range(len(training_files)), desc='Calculating game values.'):
#         game_name = training_files[i]
#         action_mean_dict, action_var_dict, _ = compute_action_variance_by_location(agent=agent,
#                                                                                    game_name=game_name,
#                                                                                    is_home=is_home,
#                                                                                    action_var_dict=action_var_dict,
#                                                                                    action_mean_dict=action_mean_dict,
#                                                                                    action_samples_dict=action_samples_dict)
#     var_location_data = action_var_dict[target_action]
#     mean_location_data = action_mean_dict[target_action]
#     if is_home:
#         var_data = np.sqrt(var_location_data[:, 0])
#         mean_data = mean_location_data[:, 0]
#     else:
#         var_data = np.sqrt(var_location_data[:, 1])
#         mean_data = mean_location_data[:, 1]
#
#     var_location_data = var_location_data[:, 3:]
#     mean_location_data = mean_location_data[:, 3:]
#
#     tmp = stats.pearsonr(x=mean_data, y=var_data)
#     print(tmp)
#
#     plot_scatter(x=var_location_data[:, 0], y=var_location_data[:, 1], z=var_data, label='Std', plot_name='tmp_std')
#     plot_scatter(x=mean_location_data[:, 0], y=mean_location_data[:, 1], z=mean_data, label='Exp', plot_name='tmp_exp')
#
#     bin_size = 5
#     bin_expected_values, bin_num_store = calculate_location_bin_expectation(locations=var_location_data,
#                                                                             values=var_data,
#                                                                             bin_x=bin_size, bin_y=bin_size)
#     plot_heatmap(data_store=bin_expected_values, plot_name='heat_map_exp_var_bin_{0}'.format(bin_size))
#     bin_size = 1
#     bin_expected_values, bin_num_store = calculate_location_bin_expectation(locations=var_location_data,
#                                                                             values=var_data,
#                                                                             bin_x=bin_size, bin_y=bin_size)
#     plot_heatmap(data_store=bin_expected_values, plot_name='heat_map_exp_var_bin_{0}'.format(bin_size))
#
#     bin_size = 5
#     bin_expected_values, bin_num_store = calculate_location_bin_expectation(locations=mean_location_data,
#                                                                             values=mean_data,
#                                                                             bin_x=bin_size, bin_y=bin_size)
#     plot_heatmap(data_store=bin_expected_values, plot_name='heat_map_exp_mean_bin_{0}'.format(bin_size))
#     bin_size = 1
#     bin_expected_values, bin_num_store = calculate_location_bin_expectation(locations=mean_location_data,
#                                                                             values=mean_data,
#                                                                             bin_x=bin_size, bin_y=bin_size)
#     plot_heatmap(data_store=bin_expected_values, plot_name='heat_map_exp_mean_bin_{0}'.format(bin_size))
#
#     return


def empirical_action_output_with_features(agent,
                                          game_name,
                                          is_home,
                                          sports,
                                          action_outcome_dict,
                                          interested_feature_names=('xAdjCoord', 'yAdjCoord'),
                                          ):
    """Collect look-ahead outcomes and selected features per action for one game.

    Each event whose rounded ``'home'`` flag equals ``is_home`` contributes one
    row: the three cumulative rewards returned by ``reward_look_ahead``
    followed by the de-standardised values of ``interested_feature_names``.
    Rows are grouped under the action with the largest positive indicator
    (``None`` when no indicator is positive).

    Args:
        agent: Object providing the data loaders, the transition builders,
            ``apply_rnn`` and ``gamma``.
        game_name: Label of the game to load.
        is_home: Value compared with the rounded ``'home'`` feature.
        sports: ``'ice-hockey'`` or ``'soccer'``.
        action_outcome_dict: Dict of action -> list of rows, extended in place.
        interested_feature_names: Feature names appended to every row.

    Returns:
        The same ``action_outcome_dict`` after adding this game's rows.

    Raises:
        ValueError: If ``sports`` is not a known sport.
    """
    if sports == 'ice-hockey':
        data_means, data_stds = read_feature_mean_scale(data_dir='../icehockey-data/')
        sport_actions = ICEHOCKEY_ACTIONS
    elif sports == 'soccer':
        data_means, data_stds = read_feature_mean_scale(data_dir='../soccer-data/')
        sport_actions = SOCCER_ACTIONS
    else:
        raise ValueError("Unknown sports: {0}".format(sports))

    pid_sequence = agent.load_player_id(game_label=game_name)
    s_a_sequence, r_sequence = agent.load_sports_data(game_label=game_name, need_check=False)
    build = agent.build_rnn_transitions if agent.apply_rnn else agent.build_transitions
    transition_all = build(s_a_data=s_a_sequence,
                           r_data=r_sequence,
                           pid_sequence=pid_sequence)

    next_round_idx = float('inf')
    for event_idx, transition in enumerate(transition_all):
        use_rnn = agent.apply_rnn
        # For recurrent transitions the current event is the last valid step
        # of the trace.
        if use_rnn:
            standardized = transition.state_action[transition.trace - 1]
        else:
            standardized = transition.state_action
        state_action_origin = reverse_standard_data(state_action_data=to_np(standardized),
                                                    data_means=data_means,
                                                    data_stds=data_stds,
                                                    sports=sports)
        if use_rnn:
            reward_step = transition.next_trace - 1
            reward_h = transition.reward_h[reward_step]
            reward_a = transition.reward_a[reward_step]
            reward_n = transition.reward_n[reward_step]
        else:
            reward_h = transition.reward_h
            reward_a = transition.reward_a
            reward_n = transition.reward_n

        # The event right before the previously reported round boundary must
        # carry a reward.
        if event_idx == next_round_idx - 1:
            assert reward_h + reward_a + reward_n > 0

        next_round_idx, h_cumu_rewards, a_cumu_rewards, n_cumu_rewards = \
            reward_look_ahead(transition_all, event_idx, apply_rnn=agent.apply_rnn, gamma=agent.gamma)

        # Keep only the events of the requested side.
        if round(state_action_origin['home']) != is_home:
            continue

        # The performed action is the one with the largest positive indicator.
        action = None
        best_indicator = 0
        for name in sport_actions:
            indicator = state_action_origin[name]
            if indicator > best_indicator:
                best_indicator = indicator
                action = name

        feature_values = np.asarray([state_action_origin[feature_name]
                                     for feature_name in interested_feature_names])
        cumulative_rewards = np.asarray([h_cumu_rewards, a_cumu_rewards, n_cumu_rewards])
        # Row layout: (r_h, r_a, r_n, *features).
        outcome_row = np.concatenate((cumulative_rewards, feature_values), axis=0)

        action_outcome_dict.setdefault(action, []).append(outcome_row)
    return action_outcome_dict


def compute_action_output_with_features(agent,
                                        game_name,
                                        is_home,
                                        action_samples_dict,
                                        sports,
                                        interested_feature_names,
                                        output_type='QValues',
                                        sanity_check_msg=None,
                                        uncertainty_model='gda'):
    """Collect model outputs, advantages and selected features per action.

    Each event whose rounded ``'home'`` flag equals ``is_home`` contributes one
    sample ``[output[i], advantages, features]``. ``advantages`` is
    ``output[i] - output[i - 1]``, or zeros for the first event and for
    ``'goal'`` actions. Samples are grouped under the action with the largest
    positive indicator (``None`` when no indicator is positive).

    Args:
        agent: Object providing the value/uncertainty computation, the data
            loaders, the transition builders and ``apply_rnn``.
        game_name: Label of the game to process.
        is_home: Value compared with the rounded ``'home'`` feature; also
            forwarded as ``use_home`` for the uncertainty output.
        action_samples_dict: Dict of action -> list of samples, extended in
            place.
        sports: ``'ice-hockey'`` or ``'soccer'``.
        interested_feature_names: Feature names stored with every sample.
        output_type: ``'QValues'`` or ``'Uncertainty'``.
        sanity_check_msg: Forwarded to the agent's computation method.
        uncertainty_model: Forwarded to ``agent.compute_uncertainty_by_game``.

    Returns:
        The same ``action_samples_dict`` after adding this game's samples.

    Raises:
        ValueError: If ``output_type`` or ``sports`` is not recognised.
    """
    # Reject an unsupported output type up front; otherwise `output` would
    # stay unbound and fail later with an unrelated error.
    if output_type not in ('QValues', 'Uncertainty'):
        raise ValueError("Unknown output type: {0}".format(output_type))

    if sports == 'ice-hockey':
        data_dir = '../icehockey-data/'
        sport_actions = ICEHOCKEY_ACTIONS
    elif sports == 'soccer':
        data_dir = '../soccer-data/'
        sport_actions = SOCCER_ACTIONS
    else:
        raise ValueError("Unknown sports: {0}".format(sports))
    data_means, data_stds = read_feature_mean_scale(data_dir=data_dir)

    if output_type == 'QValues':
        output, _ = agent.compute_values_by_game(game_name=game_name, sanity_check_msg=sanity_check_msg)
    else:
        output, _ = agent.compute_uncertainty_by_game(game_name=game_name,
                                                      sanity_check_msg=sanity_check_msg,
                                                      use_home=is_home,
                                                      uncertainty_model=uncertainty_model)

    s_a_sequence, r_sequence = agent.load_sports_data(game_label=game_name, need_check=False)  # load all the features
    pid_sequence = agent.load_player_id(game_label=game_name)
    if agent.apply_rnn:
        transition_all = agent.build_rnn_transitions(s_a_data=s_a_sequence,
                                                     r_data=r_sequence,
                                                     pid_sequence=pid_sequence)
    else:
        transition_all = agent.build_transitions(s_a_data=s_a_sequence,
                                                 r_data=r_sequence,
                                                 pid_sequence=pid_sequence)

    for i, transition in enumerate(transition_all):
        # For recurrent transitions the current event is the last valid step
        # of the trace.
        if agent.apply_rnn:
            state_action_data = transition.state_action[transition.trace - 1]
        else:
            state_action_data = transition.state_action
        state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                    data_means=data_means,
                                                    data_stds=data_stds,
                                                    sports=sports)
        # Keep only the events of the requested side.
        if round(state_action_origin['home']) != is_home:
            continue

        # The performed action is the one with the largest positive indicator.
        action = None
        max_action_label = 0
        for candidate_action in sport_actions:
            if state_action_origin[candidate_action] > max_action_label:
                max_action_label = state_action_origin[candidate_action]
                action = candidate_action

        interested_features = np.asarray([state_action_origin[feature_name]
                                          for feature_name in interested_feature_names])

        # No previous event to difference against at i == 0; goals are given
        # a zero advantage as well.
        if action == 'goal' or i == 0:
            advantages = np.zeros_like(output[0]).squeeze()
        else:
            advantages = output[i] - output[i - 1]

        sample_location_event = [output[i], advantages, interested_features]
        action_samples_dict.setdefault(action, []).append(sample_location_event)
    return action_samples_dict


def compute_action_variance_all_game(agent, sanity_check_msg):
    """Print the average output standard deviation per action over all games.

    Every file listed in ``agent.train_data_path`` is passed to
    ``compute_action_variance``; the accumulated variances are turned into a
    mean standard deviation per action and printed from largest to smallest.

    Args:
        agent: Object exposing ``train_data_path`` and whatever
            ``compute_action_variance`` needs.
        sanity_check_msg: Forwarded to ``compute_action_variance``.
    """
    action_var_dict = {}
    for game_name in os.listdir(agent.train_data_path):
        print(game_name)
        action_var_dict = compute_action_variance(agent, game_name, sanity_check_msg,
                                                  action_var_dict=action_var_dict)

    # Standard deviation is the square root of each stored variance.
    avg_action_std_dict = {action: np.mean(np.sqrt(variances))
                           for action, variances in action_var_dict.items()}

    ranking = sorted(avg_action_std_dict.items(), key=operator.itemgetter(1), reverse=True)
    for action, avg_std in ranking:
        print(action, avg_std)


def compute_action_variance(agent, game_name, sanity_check_msg, action_var_dict=None):
    """Accumulate the per-event output variance of one game, grouped by action.

    The variance is taken over the last axis of the outputs returned by
    ``agent.compute_values_by_game``. Each event's variance is appended to the
    array stored under the action with the largest positive indicator
    (``None`` when no indicator is positive).

    Args:
        agent: Object exposing ``sports``, ``apply_rnn`` and
            ``compute_values_by_game``.
        game_name: Label of the game to process.
        sanity_check_msg: Forwarded to ``agent.compute_values_by_game``.
        action_var_dict: Dict of action -> stacked variances to extend; a new
            dict is created when omitted.

    Returns:
        The dict after adding this game's variances.

    Raises:
        ValueError: If ``agent.sports`` is not a known sport.
    """
    # A fresh dict per call avoids sharing state through a mutable default.
    if action_var_dict is None:
        action_var_dict = {}

    # Pick the action list that matches the sport instead of always scanning
    # the ice-hockey actions.
    if agent.sports == 'ice-hockey':
        data_dir = '../icehockey-data/'
        sport_actions = ICEHOCKEY_ACTIONS
    elif agent.sports == 'soccer':
        data_dir = '../soccer-data/'
        sport_actions = SOCCER_ACTIONS
    else:
        raise ValueError("Unknown sports: {0}".format(agent.sports))
    data_means, data_stds = read_feature_mean_scale(data_dir=data_dir)

    output, transition_all = agent.compute_values_by_game(game_name, sanity_check_msg)
    output_vars = np.var(output, axis=-1)

    for i, transition in enumerate(transition_all):
        # For recurrent transitions the current event is the last valid step
        # of the trace.
        if agent.apply_rnn:
            state_action_data = transition.state_action[transition.trace - 1]
        else:
            state_action_data = transition.state_action
        state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                    data_means=data_means,
                                                    data_stds=data_stds,
                                                    sports=agent.sports)

        # The performed action is the one with the largest positive indicator.
        action = None
        max_action_label = 0
        for candidate_action in sport_actions:
            if state_action_origin[candidate_action] > max_action_label:
                max_action_label = state_action_origin[candidate_action]
                action = candidate_action

        # Add a leading axis so events can be stacked along axis 0.
        var_event = np.expand_dims(output_vars[i], axis=0)
        if action in action_var_dict:
            action_var_dict[action] = np.concatenate((action_var_dict[action], var_event), axis=0)
        else:
            action_var_dict[action] = var_event

    return action_var_dict


def generate_game_plot(agent, game_name, episode_num,
                       sanity_check_msg, date_label, alpha='mean', debug_msg='',
                       plot_save_path=None):
    """Plot the model outputs of one game and report their spread.

    A curve of the plotted values against the event index is always drawn.
    When ``'distrib'`` is part of ``agent.task`` a mean/std shadow curve over
    a fixed event window (3200-4000 for ice hockey, 0-2000 for soccer) is
    drawn as well.

    Args:
        agent: Object exposing ``task``, ``sports``, ``num_tau``,
            ``all_gda_models`` and ``compute_uncertainty_by_game``.
        game_name: Label of the game to plot.
        episode_num: Value embedded in the default plot path.
        sanity_check_msg: Forwarded to the value/uncertainty computations.
        date_label: Forwarded to ``get_distrib_q_model_save_path``.
        alpha: ``'mean'`` plots the mean over the last output axis; any other
            value plots index ``int(agent.num_tau * alpha)`` of that axis
            (distributional tasks only).
        debug_msg: Forwarded to ``get_distrib_q_model_save_path``.
        plot_save_path: Target plot name; derived from the model save path
            when ``None``.

    Returns:
        ``(mean, max, min)`` of the output standard deviation for
        distributional tasks, ``(None, None, None)`` otherwise.

    Raises:
        ValueError: If ``agent.sports`` is not a known sport.
    """
    game_time_all, transition_game, output_game = get_game_time(agent=agent, game_name=game_name,
                                                                sanity_check_msg=sanity_check_msg)

    # The uncertainties are computed but not used by any active code path
    # below.
    uncertainties_game = None
    if agent.all_gda_models is not None:
        uncertainties_game, _ = agent.compute_uncertainty_by_game(game_name=game_name,
                                                                  sanity_check_msg=sanity_check_msg,
                                                                  transition_game=transition_game,
                                                                  )

    is_distrib = 'distrib' in agent.task
    if is_distrib:
        output_mean_all = np.mean(output_game, axis=-1)
        if alpha == 'mean':
            plot_value_all = output_mean_all
        else:
            plot_value_all = output_game[:, :, int(agent.num_tau * alpha)]
    else:
        plot_value_all = output_mean_all = output_game
        uncertainties_game = None
        mean_game_std = max_game_std = min_game_std = None

    output_var_all = np.var(output_game, axis=-1)

    if plot_save_path is None:
        task_dir = 'distrib_dqn' if agent.task == 'train_distrib_rl' else 'dqn'
        model_file = get_distrib_q_model_save_path(agent, date_label, debug_msg).split('/')[-1]
        plot_save_path = ('./figures_for_evaluation/{0}/'.format(task_dir)
                          + model_file.replace('saved', 'game-{0}'.format(game_name))
                          + '_episode-{0}'.format(episode_num))

    event_num = list(range(output_game.shape[0]))

    # Fixed event window used by the shadow curve.
    if agent.sports == 'ice-hockey':
        plot_start_idx, plot_end_idx = 3200, 4000
    elif agent.sports == 'soccer':
        plot_start_idx, plot_end_idx = 0, 2000
    else:
        raise ValueError("Unknown sports: {0}".format(agent.sports))

    curve_keys = ('home', 'away', 'end')
    plot_curve(draw_keys=['home', 'away'],
               x_dict=dict.fromkeys(curve_keys, event_num),
               y_dict={key: plot_value_all[:, column] for column, key in enumerate(curve_keys)},
               img_size=(7, 6),
               linewidth=2,
               xlabel='event_number',
               ylabel='Q values',
               plot_name=plot_save_path)

    if is_distrib:
        output_std_all = np.sqrt(output_var_all)
        mean_game_std = np.mean(output_std_all)
        max_game_std = np.max(output_std_all)
        min_game_std = np.min(output_std_all)

        window = slice(plot_start_idx, plot_end_idx)
        shadow_keys = ('Home', 'Away', 'Neither')
        plot_shadow_curve(draw_keys=['Home', 'Away'],
                          x_dict_mean={key: game_time_all[window] for key in shadow_keys},
                          y_dict_mean={key: output_mean_all[:, column][window]
                                       for column, key in enumerate(shadow_keys)},
                          x_dict_std={key: game_time_all[window] for key in shadow_keys},
                          y_dict_std={key: output_std_all[:, column][window]
                                      for column, key in enumerate(shadow_keys)},
                          linewidth=1,
                          xlabel='Game Time (in Seconds)',
                          ylabel='Action-Values',
                          plot_name=plot_save_path,
                          linestyle_dict={'Home': '-', 'Away': '--', 'Neither': '-.'},
                          img_size=(9, 5),
                          ylim=(0, 1))

    return mean_game_std, max_game_std, min_game_std


def sanity_check_by_goal(agent, training_files, game_num, sanity_check_msg):
    """Debug helper that scans games for events carrying a positive reward.

    Iterates over ``training_files[-game_num:]``. For every game the agent's
    values are computed, each transition is de-standardized with the feature
    statistics read from ``'../icehockey-data/'``, and ``"find you"`` is
    printed whenever ``reward_h + reward_a + reward_n`` of an event is
    positive. When the last event of a game is reached, ``reward_n`` and
    ``next_trace`` of that last transition are printed too.

    :param agent: object exposing ``compute_values_by_game``, ``apply_rnn``
        and ``sports``.
    :param training_files: sequence of game names.
    :param game_num: how many games to take from the end of ``training_files``.
    :param sanity_check_msg: forwarded to ``agent.compute_values_by_game``.
    :return: None; results are only printed.
    """
    feature_means, feature_stds = read_feature_mean_scale(data_dir='../icehockey-data/')

    for game_name in training_files[-game_num:]:
        print(game_name)
        output, transition_all = agent.compute_values_by_game(game_name, sanity_check_msg)

        # Average the model output over its last axis.
        output_means = np.mean(output, axis=-1)

        for event_idx in range(len(transition_all)):
            transition = transition_all[event_idx]
            mean_event = output_means[event_idx]  # not used further; kept for debugger inspection
            use_rnn = agent.apply_rnn

            # With an RNN the transition stores a sequence; take the entry at trace - 1.
            if use_rnn:
                state_action_data = transition.state_action[transition.trace - 1]
            else:
                state_action_data = transition.state_action
            state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                        data_means=feature_means,
                                                        data_stds=feature_stds,
                                                        sports=agent.sports)

            # Rewards are indexed by next_trace - 1 in the RNN case.
            if use_rnn:
                reward_pos = transition.next_trace - 1
                reward_h = transition.reward_h[reward_pos]
                reward_a = transition.reward_a[reward_pos]
                reward_n = transition.reward_n[reward_pos]
            else:
                reward_h = transition.reward_h
                reward_a = transition.reward_a
                reward_n = transition.reward_n

            # Pick the action with the largest strictly positive indicator value.
            action, max_action_label = None, 0
            for candidate_action in ICEHOCKEY_ACTIONS:
                candidate_label = state_action_origin[candidate_action]
                if candidate_label > max_action_label:
                    action, max_action_label = candidate_action, candidate_label

            total_reward = reward_h + reward_a + reward_n
            if total_reward > 0:
                print("find you")

            if event_idx == len(transition_all) - 1:
                final_transition = transition_all[-1]
                print(final_transition.reward_n)
                print(final_transition.next_trace)


def plot_game_quantiles_values(agent,
                               game_name,
                               is_home,
                               sanity_check_msg,
                               action_samples_dict=None,
                               action_only=False,
                               target_action='shot'):
    """Collect per-action model outputs for one game and plot histograms for a target action.

    Every event of ``game_name`` is de-standardized with the feature statistics
    read from ``'../icehockey-data/'``. The performed action is taken to be the
    entry of ``ICEHOCKEY_ACTIONS`` with the largest strictly positive indicator
    value. The model output of each kept event is appended to
    ``action_samples_dict[action]``, and for events whose action name contains
    ``target_action`` a 20-bin histogram of ``output[i][0]`` (if ``is_home``)
    or ``output[i][1]`` (otherwise) is shown.

    :param agent: object exposing ``compute_values_by_game``, ``apply_rnn``
        and ``sports``.
    :param game_name: game identifier forwarded to ``agent.compute_values_by_game``.
    :param is_home: compared against ``int(state_action_origin['home'])`` to
        filter events (unless ``action_only``); also selects which output
        column is plotted.
    :param sanity_check_msg: forwarded to ``agent.compute_values_by_game`` and
        ``reverse_standard_data``.
    :param action_samples_dict: optional dict, updated in place, mapping an
        action name to the list of outputs observed for it. Pass your own dict
        to read the samples back; when omitted a fresh dict is used per call.
    :param action_only: when true, events are not filtered by ``is_home``.
    :param target_action: substring an action name must contain to be plotted.
    :return: None.
    """
    import matplotlib.pyplot as plt

    # A literal ``{}`` default would be shared (and keep growing) across calls.
    if action_samples_dict is None:
        action_samples_dict = {}

    data_means, data_stds = read_feature_mean_scale(data_dir='../icehockey-data/')
    output, transition_all = agent.compute_values_by_game(game_name, sanity_check_msg)
    for i, transition in enumerate(transition_all):
        # With an RNN the transition stores a sequence; take the entry at trace - 1.
        if agent.apply_rnn:
            state_action_data = transition.state_action[transition.trace - 1]
        else:
            state_action_data = transition.state_action
        state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                    data_means=data_means,
                                                    data_stds=data_stds,
                                                    sanity_check_msg=sanity_check_msg,
                                                    sports=agent.sports)
        if not action_only and int(state_action_origin['home']) != is_home:
            continue

        # Pick the action with the largest strictly positive indicator value;
        # ``action`` stays None when no indicator is positive.
        action = None
        max_action_label = 0
        for candidate_action in ICEHOCKEY_ACTIONS:
            if state_action_origin[candidate_action] > max_action_label:
                max_action_label = state_action_origin[candidate_action]
                action = candidate_action

        action_samples_dict.setdefault(action, []).append(output[i])

        # Guard against ``action is None``: ``target_action in None`` would raise TypeError.
        if action is not None and target_action in action:
            # Draw on an explicit axes: pandas opens a new figure when no ``ax``
            # is supplied, which previously left an empty extra figure per event.
            fig, ax = plt.subplots()
            df = pd.DataFrame({"{0}".format(action): output[i][0] if is_home else output[i][1]})
            df.plot.hist(bins=20, alpha=0.5, ax=ax)
            plt.show()


def evaluate_playing_style(agent, debug_mode, target_action='shot', target_team_id='18'):
    """Scatter-plot a t-SNE embedding of value predictions for a team's most active players.

    For every game in ``agent.train_data_path`` that involves ``target_team_id``,
    the events performed by that team whose action name contains
    ``target_action`` are collected, grouped by player id. The five players
    with the most collected events are kept; each of their predictions is
    summarized by its mean and variance along axis 1, embedded in 2-D with
    t-SNE (perplexity 5.0) and passed to ``plot_scatter`` with one group per
    player.

    :param agent: object exposing ``sports``, ``train_data_path``,
        ``train_rate``, ``apply_data_date_div``, ``source_data_dir``,
        ``apply_rnn``, ``compute_values_by_game`` and ``load_player_id``.
    :param debug_mode: accepted for interface compatibility; it currently has
        no effect on the games that are processed.
    :param target_action: substring an action name must contain to be kept.
    :param target_team_id: team id compared against the ids returned by
        ``read_data``.
    :raises ValueError: if ``agent.sports`` is neither ``'ice-hockey'`` nor
        ``'soccer'``, or if no matching event is found (nothing to embed).
    :return: None; the figure is produced by ``plot_scatter``.
    """
    # Feature statistics and the action vocabulary depend only on the sport,
    # so resolve both once instead of once per event.
    if agent.sports == 'ice-hockey':
        data_means, data_stds = read_feature_mean_scale(data_dir='../icehockey-data/')
        sport_actions = ICEHOCKEY_ACTIONS
    elif agent.sports == 'soccer':
        data_means, data_stds = read_feature_mean_scale(data_dir='../soccer-data/')
        sport_actions = SOCCER_ACTIONS
    else:
        raise ValueError("Unknown sports: {0}".format(agent.sports))

    player_info = read_player_info()

    all_files = sorted(os.listdir(agent.train_data_path))
    # NOTE(review): the train/valid/test split is not used below (all games are
    # scanned); the call is kept in case it has side effects - confirm and remove.
    training_files, valid_files, testing_files = \
        divide_dataset_according2date(all_data_files=all_files,
                                      train_rate=agent.train_rate,
                                      sports=agent.sports,
                                      if_split=agent.apply_data_date_div)
    game_files = all_files

    all_player_ids_value_predictions_dict = {}
    for game_name in game_files:
        _, _, team_ids = read_data(source_data_dir=agent.source_data_dir,
                                   file_name=game_name + '-playsequence-wpoi.json',
                                   output_team_ids=True)
        # NOTE(review): team_ids[0] is treated as the away team and team_ids[1]
        # as the home team - confirm against read_data.
        if target_team_id == team_ids[0]:
            target_h_a = 'away'
        elif target_team_id == team_ids[1]:
            target_h_a = 'home'
        else:
            continue
        print('handling game {0}'.format(game_name))
        output_game, transition_game = agent.compute_values_by_game(game_name=game_name,
                                                                    sanity_check_msg=None)
        player_ids = agent.load_player_id(game_label=game_name)
        for i, transition in enumerate(transition_game):
            # With an RNN the transition stores a sequence; take the entry at trace - 1.
            if agent.apply_rnn:
                state_action_data = transition.state_action[transition.trace - 1]
            else:
                state_action_data = transition.state_action
            state_action_origin = reverse_standard_data(state_action_data=to_np(state_action_data),
                                                        data_means=data_means,
                                                        data_stds=data_stds,
                                                        sports=agent.sports)
            event_h_a = 'home' if state_action_origin['home'] > state_action_origin['away'] else 'away'
            if event_h_a != target_h_a:
                continue

            # Pick the action with the largest strictly positive indicator value.
            action = None
            max_action_label = 0
            for candidate_action in sport_actions:
                if state_action_origin[candidate_action] > max_action_label:
                    max_action_label = state_action_origin[candidate_action]
                    action = candidate_action
            # ``action`` is None when no indicator is positive; skip such events
            # instead of failing on ``target_action not in None``.
            if action is None or target_action not in action:
                continue

            # Output column 0 holds the home prediction, column 1 the away one.
            value_idx = 0 if event_h_a == 'home' else 1
            all_player_ids_value_predictions_dict.setdefault(player_ids[i], []).append(
                output_game[i][value_idx])

    # Rank players by number of collected events (ties broken by player id,
    # descending) and keep the five most active ones.
    marklist = sorted([(len(value), pid) for (pid, value) in all_player_ids_value_predictions_dict.items()],
                      reverse=True)
    top_pids = [item[1] for item in marklist[:5]]
    if not top_pids:
        # Without this check the axis-1 reductions below fail with an opaque AxisError.
        raise ValueError("No '{0}' events found for team {1}; nothing to plot.".format(target_action,
                                                                                     target_team_id))

    plot_value_predictions = []
    for pid in top_pids:
        plot_value_predictions += all_player_ids_value_predictions_dict[pid]
    plot_value_predictions = np.asarray(plot_value_predictions)

    # Summarize each prediction by (mean, variance) before the t-SNE projection.
    # NOTE(review): recent scikit-learn versions reject perplexity >= n_samples,
    # so very small selections may still fail here - confirm the expected data size.
    prediction_means = plot_value_predictions.mean(axis=1)
    prediction_vars = plot_value_predictions.var(axis=1)
    plot_value_predictions = np.stack([prediction_means, prediction_vars], axis=-1)
    plot_embeddings = TSNE(n_components=2, perplexity=5.0).fit_transform(plot_value_predictions)

    # The rows of ``plot_embeddings`` follow ``top_pids`` order, with each
    # player's events contiguous, so slice them back out per player.
    dr_embeddings = []
    labels = []
    offset = 0
    for pid in top_pids:
        sample_count = len(all_player_ids_value_predictions_dict[pid])
        dr_embeddings.append(plot_embeddings[offset:offset + sample_count])
        labels.append('-'.join(player_info[pid]))
        offset += sample_count

    plot_scatter(scatter_data=dr_embeddings,
                 labels=labels,
                 plot_name='scatter_quantiles_pid',
                 xlim=None,
                 ylim=None)
